# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use

import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F


class BaseNet(nn.Module):
    """ Takes a list of images as input, and returns for each image:
        - a pixelwise descriptor
        - a pixelwise confidence

        Subclasses implement ``forward_one`` for a single input; ``forward`` runs it
        on every element of the input list and gathers the outputs by key.
    """

    def softmax(self, ux):
        """ Squash raw confidence logits into a single-channel map in [0, 1].

        ux: tensor holding the logits along dim 1, with either
            - 1 channel: softplus(ux) / (1 + softplus(ux)), or
            - 2 channels: softmax over dim 1, keeping the second channel.
        Returns a tensor with one channel along dim 1.
        Raises ValueError for any other number of channels.
        """
        if ux.shape[1] == 1:
            x = F.softplus(ux)
            return x / (1 + x)  # for sure in [0,1], much less plateaus than softmax
        if ux.shape[1] == 2:
            return F.softmax(ux, dim=1)[:, 1:2]
        # Any other channel count is a configuration error: fail loudly here
        # instead of returning None and crashing somewhere downstream.
        raise ValueError(f"expected 1 or 2 confidence channels along dim 1, got {ux.shape[1]}")

    def normalize(self, x, ureliability, urepeatability):
        """ Package raw network outputs into the output dict.

        x: raw descriptors, L2-normalized along dim 1.
        ureliability, urepeatability: raw confidence logits, squashed by ``softmax``.
        """
        return dict(descriptors=F.normalize(x, p=2, dim=1),
                    repeatability=self.softmax(urepeatability),
                    reliability=self.softmax(ureliability))

    def forward_one(self, x):
        raise NotImplementedError()

    def forward(self, imgs, **kw):
        """ Run ``forward_one`` on each element of ``imgs``.

        Returns a dict mapping every output key to the list of per-image values
        (an image whose output lacks a key is skipped for that key), plus
        ``imgs`` itself and any extra keyword arguments.
        """
        res = [self.forward_one(img) for img in imgs]
        # merge all dictionaries into one; dict.fromkeys keeps a deterministic,
        # first-seen key order (a set would not)
        keys = dict.fromkeys(k for r in res for k in r)
        merged = {k: [r[k] for r in res if k in r] for k in keys}
        return dict(merged, imgs=imgs, **kw)


class PatchNet(BaseNet):
    """ Helper class to construct a fully-convolutional network that
        extracts an l2-normalized patch descriptor.

        Subclasses build the network in their constructor by calling ``_add_conv``,
        which appends layers to ``self.ops``.
    """

    def __init__(self, inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False):
        BaseNet.__init__(self)
        self.inchan = inchan  # number of input channels
        self.curchan = inchan  # output channels of the last layer added so far
        # if True, strides are not applied; they multiply the dilation of later layers instead
        self.dilated = dilated
        self.dilation = dilation  # dilation factor that the next conv layer will use
        self.bn = bn  # global switch for BatchNorm after convs
        self.bn_affine = bn_affine
        self.ops = nn.ModuleList([])

    def _make_bn(self, outd):
        return nn.BatchNorm2d(outd, affine=self.bn_affine)

    def _add_conv(self, outd, k=3, stride=1, dilation=1, bn=True, relu=True, k_pool=1, pool_type='max'):
        """ Append a conv layer (+ optional BatchNorm, ReLU and pooling) to ``self.ops``.

        outd: output channels; k: kernel size; stride: conv stride (turned into
        a dilation increase for later layers when ``self.dilated``); dilation:
        extra dilation multiplier for this layer; bn / relu: add those layers;
        k_pool: pooling kernel size (no pooling if <= 1); pool_type: 'avg' or 'max'.
        Raises ValueError for an unknown ``pool_type`` when pooling is requested.
        """
        # Validate before appending anything so a bad config never leaves half-built layers.
        if k_pool > 1 and pool_type not in ('avg', 'max'):
            raise ValueError(f"unknown pooling type {pool_type!r}, expected 'avg' or 'max'")

        # as in the original implementation, dilation is applied at the end of layer, so it will have impact only from next layer
        d = self.dilation * dilation
        # padding keeps the spatial size whenever (k - 1) * d is even
        padding = ((k - 1) * d) // 2
        if self.dilated:
            conv_params = dict(padding=padding, dilation=d, stride=1)
            self.dilation *= stride
        else:
            conv_params = dict(padding=padding, dilation=d, stride=stride)
        self.ops.append(nn.Conv2d(self.curchan, outd, kernel_size=k, **conv_params))
        if bn and self.bn:
            self.ops.append(self._make_bn(outd))
        if relu:
            self.ops.append(nn.ReLU(inplace=True))
        self.curchan = outd

        if k_pool > 1:
            pool_cls = nn.AvgPool2d if pool_type == 'avg' else nn.MaxPool2d
            self.ops.append(pool_cls(kernel_size=k_pool))

    def forward_one(self, x):
        """ Run every layer on ``x`` and return its L2-normalized descriptors.

        PatchNet has no confidence heads, so only ``descriptors`` is returned;
        subclasses with heads override this method.
        """
        assert self.ops, "You need to add convolutions first"
        for op in self.ops:
            x = op(x)
        # BaseNet.normalize also needs confidence logits that this class does not
        # compute (calling it with x alone raises TypeError), so normalize directly.
        return dict(descriptors=F.normalize(x, p=2, dim=1))


class L2_Net(PatchNet):
    """ L2-Net architecture (CVPR'17): computes a descriptor for every overlapping
        32x32 patch. The channel counts below are those of the 128D model and are
        rescaled proportionally to ``dim``.
    """

    def __init__(self, dim=128, **kw):
        PatchNet.__init__(self, **kw)
        # (channels in the 128D reference model, extra _add_conv options), in order
        layer_specs = (
            (32, {}),
            (32, {}),
            (64, {'stride': 2}),
            (64, {}),
            (128, {'stride': 2}),
            (128, {}),
            (128, {'k': 7, 'stride': 8, 'bn': False, 'relu': False}),
        )
        for ref_channels, options in layer_specs:
            self._add_conv((ref_channels * dim) // 128, **options)
        self.out_dim = dim


class Quad_L2Net(PatchNet):
    """ Same as L2_Net, but the final 8x8 conv is replaced by 3 successive 2x2 convs.
    """

    def __init__(self, dim=32, mchan=4, relu22=False, **kw):
        PatchNet.__init__(self, **kw)
        # 3x3 trunk: (multiplier applied to mchan, stride) per layer
        for mult, step in ((8, 1), (8, 1), (16, 2), (16, 1), (32, 2), (32, 1)):
            self._add_conv(mult * mchan, stride=step)
        # two 2x2 stride-2 convs stand in for the original 8x8 conv...
        wide = 32 * mchan
        for _ in range(2):
            self._add_conv(wide, k=2, stride=2, relu=relu22)
        # ...and a last 2x2 conv projects to the descriptor size, without BN/ReLU
        self._add_conv(dim, k=2, stride=2, bn=False, relu=False)
        self.out_dim = dim


class Quad_L2Net_ConfCFS(Quad_L2Net):
    """ Quad_L2Net plus two 1x1 conv heads that produce the confidence maps
        for repeatability and reliability.
    """

    def __init__(self, **kw):
        Quad_L2Net.__init__(self, **kw)
        # reliability head: 2 logits per pixel, squashed with a real softmax
        self.clf = nn.Conv2d(self.out_dim, 2, kernel_size=1)
        # repeatability head: a single logit, so BaseNet.softmax uses its softplus
        # branch rather than a softmax (kept as-is from the reference implementation)
        self.sal = nn.Conv2d(self.out_dim, 1, kernel_size=1)

    def forward_one(self, x):
        assert self.ops, "You need to add convolutions first"
        feats = x
        for layer in self.ops:
            feats = layer(feats)
        # both confidence heads look at the squared (raw) descriptors
        energy = feats ** 2
        reliability_logits = self.clf(energy)
        repeatability_logits = self.sal(energy)
        return self.normalize(feats, reliability_logits, repeatability_logits)


class Fast_Quad_L2Net(PatchNet):
    """ Faster version of Quad_L2Net.

    One strided (dilated) stage is replaced by a pooling layer, which lowers the
    feature resolution and therefore reduces inference time. The output is
    upsampled back by ``downsample_factor`` at the end.

    Dilation factors and pooling (with the defaults dilated=True, dilation=1,
    downsample_factor=2):
        1,1,1, pool2, 1,1, 2,2, 4, 8, upsample2

    Args:
        dim: descriptor dimension (output channels of the last conv).
        mchan: channel multiplier of the convolutional trunk.
        relu22: apply a ReLU after the two intermediate 2x2 convs.
        downsample_factor: pooling kernel size and matching upsampling factor.
        pool_type: 'max' (default) or 'avg' pooling.
        **kw: forwarded to PatchNet.
    """

    def __init__(self, dim=128, mchan=4, relu22=False, downsample_factor=2, pool_type='max', **kw):
        PatchNet.__init__(self, **kw)
        self._add_conv(8 * mchan)
        self._add_conv(8 * mchan)
        # pooling after this conv lowers the image resolution (max pooling by default)
        self._add_conv(16 * mchan, k_pool=downsample_factor, pool_type=pool_type)
        self._add_conv(16 * mchan)
        self._add_conv(32 * mchan, stride=2)
        self._add_conv(32 * mchan)

        # replace last 8x8 convolution with 3 2x2 convolutions
        self._add_conv(32 * mchan, k=2, stride=2, relu=relu22)
        self._add_conv(32 * mchan, k=2, stride=2, relu=relu22)
        self._add_conv(dim, k=2, stride=2, bn=False, relu=False)

        # Go back to initial image resolution with upsampling.
        # NOTE: pooling floors odd sizes, so the output matches the input size only
        # when H and W are divisible by downsample_factor.
        self.ops.append(torch.nn.Upsample(scale_factor=downsample_factor, mode='bilinear', align_corners=False))

        self.out_dim = dim


class Fast_Quad_L2Net_ConfCFS(Fast_Quad_L2Net):
    """ Fast r2d2 architecture: Fast_Quad_L2Net with repeatability and
        reliability confidence heads.
    """

    def __init__(self, **kw):
        Fast_Quad_L2Net.__init__(self, **kw)
        # reliability head: two logits per pixel, squashed with a real softmax
        self.clf = nn.Conv2d(self.out_dim, 2, kernel_size=1)

        # repeatability head: one logit per pixel, so BaseNet.softmax takes its
        # softplus branch instead of a softmax (inherited from the reference code)
        self.sal = nn.Conv2d(self.out_dim, 1, kernel_size=1)

    def forward_one(self, x):
        assert self.ops, "You need to add convolutions first"
        out = x
        for module in self.ops:
            out = module(out)
        # the confidence heads are fed the element-wise squared descriptors
        squared = out ** 2
        return self.normalize(out, self.clf(squared), self.sal(squared))


class FastL2Net(nn.Module):
    """ Placeholder network: no layers and no ``forward`` are defined yet.

    out_dim: intended descriptor dimension, stored as ``self.out_dim`` like the
        other networks in this file (it was previously accepted and discarded).
    """

    def __init__(self, out_dim=128):
        super().__init__()
        self.out_dim = out_dim


def conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, use_bn=False, dilation=1):
    """ Build a Conv2d -> [BatchNorm2d] -> ReLU block.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation:
            forwarded to ``nn.Conv2d``. ``padding`` is now honored with and
            without batch norm (the no-BN branch used to hard-code 1).
        use_bn: insert a non-affine ``nn.BatchNorm2d`` between conv and ReLU.

    Returns:
        nn.Sequential with modules indexed [conv, relu] or [conv, bn, relu].
    """
    layers = [nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                        stride=stride, padding=padding, dilation=dilation)]
    if use_bn:
        layers.append(nn.BatchNorm2d(out_channels, affine=False))
    layers.append(nn.ReLU())
    return nn.Sequential(*layers)


if __name__ == '__main__':
    import time

    # Benchmark the mean inference time of a network on a dummy image.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    img = torch.ones((1, 3, 1024, 1024), device=device)
    # img = torch.ones((1, 3, 360, 640), device=device)

    # net = Quad_L2Net_ConfCFS().to(device).eval()
    net = Fast_Quad_L2Net_ConfCFS().to(device).eval()
    print(net)

    n_warmup = 10
    n_iters = 1000

    def _sync():
        # CUDA kernels run asynchronously: wait for them so the timer measures
        # actual execution rather than just kernel launch.
        if device == 'cuda':
            torch.cuda.synchronize()

    with torch.no_grad():
        # warm-up iterations exclude one-off costs (allocator, cuDNN autotuning) from the timing
        for _ in range(n_warmup):
            net.forward_one(img)
        _sync()

        total_time = 0.0
        for _ in range(n_iters):
            start_time = time.perf_counter()
            net.forward_one(img)
            _sync()
            total_time += time.perf_counter() - start_time

    print("mean time: ", total_time / n_iters)

    # Reference timings, 480x640 input:
    #   L2NetV2: mean time:  0.06843652939796448
    #   L2NetV3: mean time:  0.05640199899673462
